from typing import List


class Solution:
    def matrixReshape(self, mat: List[List[int]], r: int, c: int) -> List[List[int]]:
        """Reshape ``mat`` into an ``r`` x ``c`` matrix in row-major order.

        Elements are read row by row from ``mat`` and written row by row into
        the new shape.

        Args:
            mat: Input matrix as a list of rows.
            r: Number of rows in the reshaped matrix.
            c: Number of columns in the reshaped matrix.

        Returns:
            The reshaped matrix as a new list of new rows. If the reshape is
            not legal, ``mat`` itself is returned unchanged. A reshape is not
            legal when ``r`` or ``c`` is negative, or when ``r * c`` differs
            from the number of elements in ``mat``.
        """
        # Flatten first and count the real elements rather than trusting
        # len(mat) * len(mat[0]). That check crashes on an empty matrix and
        # would silently drop elements from ragged rows.
        flat = [item for row in mat for item in row]
        if r < 0 or c < 0 or len(flat) != r * c:
            return mat
        # Row i of the result is the i-th consecutive run of c elements.
        return [flat[i * c:(i + 1) * c] for i in range(r)]